idx-ubyte 文件格式

idx-ubyte 是一种很简单的二进制文件格式,著名的 MNIST 使用的就是该格式。

它由一个 magic-number 和各个维度的长度组成 header,然后是主体数据。magic-number 和维度的长度都是 32 位大端无符号整数。

  • idx1-ubyte 的数据有一个维度,magic-number 的值为 0x00000801
  • idx3-ubyte 的数据有三个维度,magic-number 的值为 0x00000803
/// Illustrative layout of an idx1-ubyte file: a fixed header followed by the
/// payload bytes. Header fields are stored in the file as 32-bit big-endian
/// unsigned integers, so they need byte-swapping on little-endian hosts before
/// they can be interpreted.
struct Idx1Ubyte
{
    uint32_t magicNumber;  ///< 0x00000801 for idx1-ubyte
    uint32_t dim1;         ///< length of the single dimension
    uint8_t datas[];       ///< payload; flexible array member (C99 feature, a compiler extension in C++)
};

/// Illustrative layout of an idx3-ubyte file: a fixed header followed by the
/// payload bytes. Header fields are stored in the file as 32-bit big-endian
/// unsigned integers, so they need byte-swapping on little-endian hosts before
/// they can be interpreted.
struct Idx3Ubyte
{
    uint32_t magicNumber;  ///< 0x00000803 for idx3-ubyte
    uint32_t dim1;         ///< length of the first dimension
    uint32_t dim2;         ///< length of the second dimension
    uint32_t dim3;         ///< length of the third dimension
    uint8_t datas[];       ///< payload; flexible array member (C99 feature, a compiler extension in C++)
};

以 MNIST 为例 :

  • train-images.idx3-ubyte 是训练集图片
    • 维度 1 的值是 60000,表示包含 60000 张图片
    • 维度 2 的值是 28,表示一张图片有 28 行像素
    • 维度 3 的值是 28,表示一张图片有 28 列像素
  • train-labels.idx1-ubyte 是训练集标注
    • 维度 1 的值是 60000,表示包含 60000 个标注
#ifndef IDX_UBYTE_HPP
#define IDX_UBYTE_HPP

#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

/**
 * @brief One record of an idx-ubyte dataset: an owned byte buffer plus its shape.
 *
 * `data` is either null or a heap array obtained from `new uint8_t[]`; the
 * destructor releases it with `delete[]`. Both members are public, so code
 * that assigns `data` directly hands ownership of that array to this object.
 * The buffer is expected to hold dims[0] * ... * dims[N-1] bytes (a single
 * byte when N == 0); copying relies on that.
 *
 * @tparam N number of dimensions of one record.
 */
template<uint8_t N>
struct IdxUbyteData
{
    uint8_t* data = nullptr;
    // Keep at least one slot: a zero-length array (N == 0) is not standard C++.
    // Zero-initialised so a default-constructed record has a well-defined shape.
    uint32_t dims[N == 0 ? 1 : N] = {};

    IdxUbyteData() noexcept = default;

    ~IdxUbyteData() noexcept
    {
        delete[] data;  // deleting a null pointer is a no-op
    }

    /// Takes over the buffer of @p src and leaves @p src without one.
    IdxUbyteData(IdxUbyteData&& src) noexcept
        : data(src.data)
    {
        src.data = nullptr;
        std::memcpy(dims, src.dims, sizeof(dims));
    }

    /// Deep copy: duplicates the shape and, when @p src owns a buffer, its bytes.
    IdxUbyteData(const IdxUbyteData& src) noexcept
    {
        std::memcpy(dims, src.dims, sizeof(dims));

        // A record without a buffer stays without one instead of copying from null.
        if (src.data != nullptr)
        {
            const std::size_t bytes = byteCount();
            data = new uint8_t[bytes];
            std::memcpy(data, src.data, bytes);
        }
    }

    /// Deep-copy assignment; the previously owned buffer is released.
    IdxUbyteData& operator=(const IdxUbyteData& src) noexcept
    {
        if (this != &src)
        {
            // Copy first, then steal from the copy, so *this is only modified
            // once the new buffer exists.
            IdxUbyteData copy(src);
            *this = static_cast<IdxUbyteData&&>(copy);
        }
        return *this;
    }

    /// Move assignment: releases the current buffer and takes over the one of @p src.
    IdxUbyteData& operator=(IdxUbyteData&& src) noexcept
    {
        if (this != &src)
        {
            delete[] data;
            data = src.data;
            src.data = nullptr;
            std::memcpy(dims, src.dims, sizeof(dims));
        }
        return *this;
    }

    /// @return the product of the N dimension lengths (1 when N == 0).
    std::size_t byteCount() const noexcept
    {
        std::size_t total = 1;
        for (std::size_t i = 0; i < N; i++)
        {
            total *= dims[i];
        }
        return total;
    }
};

49template<uint8_t N>
50class IdxUbyte
51{
52public:
53    IdxUbyte() noexcept = default;
54    ~IdxUbyte() noexcept = default;
55
56    bool write(const char* file, const std::vector< IdxUbyteData<N-1> >& dataset) const noexcept
57    {
58        if (dataset.size() == 0)
59            return false;
60
61        FILE* fp = fopen(file, "wb");
62        if (fp == nullptr)
63        {
64            fprintf(stderr, "%s\n", strerror(errno));
65            return false;
66        }
67
68        this->m_write<32>(fp, MagicNumber);
69        this->m_write<32>(fp, dataset.size());
70
71        size_t bytes = 1;
72        for (uint32_t i = 0; i < N-1; i++)
73        {
74            this->m_write<32>(fp, dataset[0].dims[i]);
75            bytes *= dataset[0].dims[i];
76        }
77
78        for (const auto& data : dataset)
79        {
80            if (fwrite(data.data, 1, bytes, fp) < bytes)
81            {
82                fprintf(stderr, "%s\n", strerror(errno));
83            }
84        }
85
86        fclose(fp);
87        return true;
88    }
89
90    std::vector< IdxUbyteData<N-1> > read(const char* file) const noexcept
91    {
92        std::vector< IdxUbyteData<N-1> > ret(0);
93
94        FILE* fp = fopen(file, "rb");
95        if (fp == nullptr)
96        {
97            fprintf(stderr, "%s\n", strerror(errno));
98            return ret;
99        }
100
101        uint32_t magic = this->m_read<32>(fp);
102        if (magic != MagicNumber)
103        {
104            fprintf(stderr, "magic number mismatch: 0x%08x != 0x%08x\n", magic, MagicNumber);
105            fclose(fp);
106            return ret;
107        }
108
109        uint32_t dims[N];
110        for (size_t i = 0; i < N; i++)
111        {
112            dims[i] = this->m_read<32>(fp);
113            printf("dim %zu: %u\n", i, dims[i]);
114        }
115
116        for (uint32_t i = 0; i < dims[0]; i++)
117        {
118            size_t bytes = 1;
119            IdxUbyteData<N-1>& data = ret.emplace_back();
120            for (size_t j = 1; j < N; j++)
121            {
122                data.dims[j-1] = dims[j];
123                bytes *= dims[j];
124            }
125
126            data.data = new uint8_t[bytes];
127            if (fread(data.data, 1, bytes, fp) < bytes)
128            {
129                fprintf(stderr, "%s\n", strerror(errno));
130            }
131        }
132
133        fclose(fp);
134        return ret;
135    }
136
137private:
138    constexpr static const uint32_t MagicNumber = 0x00000800 | N;
139
140    // 大端读
141    template<size_t bits>
142    uint32_t m_read(FILE* fp) const noexcept
143    {
144        uint32_t ret = 0;
145        uint8_t byte = 0;
146
147        for (size_t i = 0; i < bits / 8; i++)
148        {
149            ret <<= 8;
150            if (fread(&byte, 1, 1, fp) < 1)
151            {
152                fprintf(stderr, "%s\n", strerror(errno));
153            }
154            ret |= byte;
155        }
156
157        return ret;
158    }
159
160    // 大端写
161    template<size_t bits>
162    void m_write(FILE* fp, uintmax_t value) const noexcept
163    {
164        constexpr const size_t bytes = bits / 8;
165        uint8_t byte = 0;
166
167        for (size_t i = 1; i <= bytes; i++)
168        {
169            byte = static_cast<uint8_t>(value >> (8 * (bytes - i)));
170            fwrite(&byte, 1, 1, fp);
171        }
172    }
173};
174
#endif // IDX_UBYTE_HPP